#include "dtw.hpp"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
/// Stores a copy of the first input sequence; search() indexes it by i.
void DTW::setXList(std::vector<double>& xList){
    xList_.assign(xList.begin(), xList.end());
}

/// Stores a copy of the second input sequence; search() indexes it by j.
void DTW::setYList(std::vector<double>& yList){
    yList_.assign(yList.begin(), yList.end());
}

/// Returns a copy of path_, the list of {i, j} index pairs filled in by search().
std::vector<std::vector<int>> DTW::getPath(){
    std::vector<std::vector<int>> result(path_.begin(), path_.end());
    return result;
}
/// Searches for the minimum-cost warping path between xList_ and yList_.
///
/// Grid cell (i, j) pairs xList_[i] with yList_[j]. Only cells with
/// i >= j and i - j < 400 may be entered (a one-sided band below the
/// diagonal). From (i, j) the path may step to (i+1, j), (i+1, j+1) or
/// (i, j+1); entering a cell costs |xList_[i] - yList_[j]|. The start cell
/// (0, 0) contributes no cost; every path starts there, so this shifts all
/// path totals equally and does not change which path is optimal.
///
/// Dijkstra's algorithm runs from (0, 0) to (rows - 1, cols - 1) using a
/// binary min-heap with lazy deletion: an improved cost pushes a fresh heap
/// entry, and entries for already-settled cells are skipped when popped.
///
/// @return true if the end cell is reachable; path_ then holds the path as
///         {i, j} pairs ordered from (0, 0) to the end cell. false otherwise
///         (including when either sequence is empty), with path_ left empty.
bool DTW::search(){
    // Drop the result of any previous search so paths never accumulate.
    path_.clear();

    const int rows = static_cast<int>(xList_.size());
    const int cols = static_cast<int>(yList_.size());
    if(rows == 0 || cols == 0)
        return false;  // no start cell; indexing the grid would be out of range

    // Cells must satisfy i - j < kBandWidth (and i >= j) to be enterable.
    const int kBandWidth = 400;

    // Per-cell search state. Cells outside the constraint start as EXPLORED so
    // they are never entered.
    std::vector<std::vector<int>> state(rows, std::vector<int>(cols, EXPLORED));
    for(int i = 0; i < rows; ++i){
        for(int j = 0; j < cols; ++j){
            if(i >= j && i - j < kBandWidth)
                state[i][j] = UNVISITED;
        }
    }

    // Best known accumulated cost to reach each cell, and the predecessor cell
    // on that best path ({-1, -1} means "none", i.e. the start cell).
    const double kInfinity = std::numeric_limits<double>::infinity();
    std::vector<std::vector<double>> best(rows, std::vector<double>(cols, kInfinity));
    std::vector<std::vector<std::pair<int, int>>> parent(
        rows, std::vector<std::pair<int, int>>(cols, std::make_pair(-1, -1)));

    // Min-heap of (cost, x, y); std::greater makes the lowest cost pop first.
    typedef std::tuple<double, int, int> Entry;
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> openList;

    // Allowed steps: advance in x, diagonally, or in y.
    static const int kMotions[3][2] = {{1, 0}, {1, 1}, {0, 1}};

    best[0][0] = 0.0;
    state[0][0] = EXPLORING;
    openList.push(Entry(0.0, 0, 0));

    while(!openList.empty()){
        double cost = 0.0;
        int x = 0;
        int y = 0;
        std::tie(cost, x, y) = openList.top();
        openList.pop();

        // Stale heap entry: this cell was already settled with a lower cost.
        if(state[x][y] == EXPLORED)
            continue;
        state[x][y] = EXPLORED;

        // Reached the end cell: walk predecessors back to the start.
        if(x == rows - 1 && y == cols - 1){
            int px = x;
            int py = y;
            while(px != -1){
                path_.push_back(std::vector<int>{px, py});
                const std::pair<int, int> prev = parent[px][py];
                px = prev.first;
                py = prev.second;
            }
            std::reverse(path_.begin(), path_.end());
            return true;
        }

        // Relax each in-bounds, not-yet-settled neighbour.
        for(const auto& motion : kMotions){
            const int nx = x + motion[0];
            const int ny = y + motion[1];
            if(nx >= rows || ny >= cols || state[nx][ny] == EXPLORED)
                continue;
            const double newCost = cost + std::fabs(xList_[nx] - yList_[ny]);
            if(newCost < best[nx][ny]){
                best[nx][ny] = newCost;
                parent[nx][ny] = std::make_pair(x, y);
                state[nx][ny] = EXPLORING;
                openList.push(Entry(newCost, nx, ny));
            }
        }
    }
    return false;
}

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
// Python bindings: exposes the DTW class as mydtw.DTW.
PYBIND11_MODULE(mydtw, m){
    m.doc() = "Dynamic time warping (DTW) path search";
    pybind11::class_<DTW>(m, "DTW")
        .def(pybind11::init())
        .def("setXList", &DTW::setXList, pybind11::arg("xList"),
             "Set the first sequence (list of floats) to align; it is copied.")
        .def("setYList", &DTW::setYList, pybind11::arg("yList"),
             "Set the second sequence (list of floats) to align; it is copied.")
        .def("search", &DTW::search,
             "Search for a minimum-cost warping path between the two sequences. "
             "Returns True if a path to the last cell was found.")
        .def("getPath", &DTW::getPath,
             "Return the stored warping path as a list of [i, j] index pairs.");
}